import functools
import warnings
from typing import Sequence, Iterator, Tuple, List, Dict, Any

import jax
import jax.numpy as np

import scalevi.models.models_base as models_base
import scalevi.distributions as dists
import scalevi.utils as utils
import scalevi.nn.nn as nn

def warning_on_one_line(message, category,
                        filename, lineno, file=None,
                        line=None):
    """Render a warning compactly as ``<CategoryName>:<message>`` plus newline.

    The signature mirrors ``warnings.formatwarning`` so this can replace it;
    ``filename``, ``lineno``, ``file`` and ``line`` are accepted but unused.
    """
    category_name = category.__name__
    return "{}:{}\n".format(category_name, message)


# Swap in the compact formatter module-wide so warnings print on one line.
warnings.formatwarning = warning_on_one_line

def select_id_map(use_test):
    """Return the data-dict key of the id map for the requested split.

    A truthy ``use_test`` selects ``"test_ids"``; anything falsy selects
    ``"train_ids"``.
    """
    if not use_test:
        return "train_ids"
    return "test_ids"

class RS(models_base.ModelBranched):
    """Hierarchical Bernoulli (logistic) rating model over movie genome features.

    The parent variable θ packs a mean vector (first ``D_kid`` entries) and
    ``D_kid*(D_kid+1)//2`` entries that are turned into a lower-triangular
    scale factor. Each chunk has a weight vector ``wi`` of length ``D_kid``
    whose prior is a multivariate normal built from θ, and its ratings are
    Bernoulli with logits ``wi · genome[movie]``.

    ``data`` must provide ``'genome'``, ``'movies'``, ``'ratings'`` and the
    per-chunk id maps ``'train_ids'`` / ``'test_ids'`` (see
    ``select_id_map``). Negative entries in an id map are padding and are
    masked out of the likelihood.
    """

    def __init__(self, N_chunk, data):

        self.data = data
        self.D_kid = self.data['genome'].shape[-1]
        # θ = [mean (D_kid) | lower-triangular entries (D_kid*(D_kid+1)/2)]
        self.D_par = self.D_kid + self.D_kid*(self.D_kid+1)//2

        super().__init__(N_chunk)

    def eval_parent(self, θ, params, **kwargs):
        """Log-density of θ under ``DiagonalNormal(0, 1)`` of size ``D_par``."""
        return dists.DiagonalNormal(
                        np.zeros(self.D_par),
                        np.ones(self.D_par)).log_prob(θ)

    def _chunk_log_lik_terms(self, wi, chunk, use_test=False):
        """Return ``(logp_y, valid)`` for one chunk's ratings given weights ``wi``.

        ``logp_y`` holds per-rating Bernoulli log-likelihoods, zeroed where the
        id map entry is padding; ``valid`` is the boolean mask ``idx >= 0``.
        """
        idx = np.take(self.data[select_id_map(use_test)], chunk, 0)
        # The gathers still run for padded (negative) ids; their contribution
        # is zeroed by ``valid`` below.
        x = np.take(self.data['genome'], np.take(self.data['movies'], idx), 0)
        y = np.take(self.data['ratings'], idx)
        valid = idx >= 0
        logp_y = valid*dists.Bernoulli(
                    logits=np.sum(wi*x, axis=-1)).log_prob(y)
        return logp_y, valid

    def eval_child(self, θ, params, wi, chunk, **kwargs):
        """log p(wi | θ) + Σ log p(ratings of ``chunk`` | wi).

        ``kwargs['use_test']`` (default False) selects the test id map.
        """
        logp_w = dists.MultivariateNormal(
                θ[:self.D_kid],
                scale_tril=dists.ProximalScaleTransform(1.0).forward(
                                dists.util.vec_to_tril_matrix(
                                    .5*θ[self.D_kid:]))).log_prob(wi)
        logp_y, _ = self._chunk_log_lik_terms(
                        wi, chunk, kwargs.get("use_test", False))
        return np.sum(logp_y) + logp_w

    def eval_child_ll(self, wi, chunk, **kwargs):
        """Summed rating log-likelihood for ``chunk``, or over all chunks.

        When ``chunk`` is None, ``wi`` must be stacked weights with leading
        axis ``N_chunk``; per-chunk totals are vmapped and summed.
        """
        use_test = kwargs.get("use_test", False)

        def _chunk_ll(wi, chunk):
            logp_y, _ = self._chunk_log_lik_terms(wi, chunk, use_test)
            return np.sum(logp_y)

        if chunk is not None:
            return _chunk_ll(wi, chunk)
        return np.sum(jax.vmap(_chunk_ll)(wi, np.arange(self.N_chunk)))

    def eval_child_mean_ll(self, wi, chunk, **kwargs):
        """Per-rating mean log-likelihood for ``chunk``, or averaged over chunks.

        The per-chunk mean divides by the number of non-padding ratings.
        When ``chunk`` is None, ``wi`` must have shape ``(N_chunk, D_kid)``.
        """
        use_test = kwargs.get("use_test", False)

        def _chunk_ll(wi, chunk):
            logp_y, valid = self._chunk_log_lik_terms(wi, chunk, use_test)
            return np.sum(logp_y)/valid.sum()

        if chunk is not None:
            return _chunk_ll(wi, chunk)
        assert wi.shape == (self.N_chunk, self.D_kid)
        return np.mean(jax.vmap(_chunk_ll)(wi, np.arange(self.N_chunk)))

    def log_prob(self, z, params, chunk, **kwargs):
        """Joint log-density; delegates to ``ModelBranched.log_prob``."""
        return super().log_prob(z, params, chunk, **kwargs)

    def log_prob_batched(self, z, params, chunk):
        """Memory-efficient joint log-density over the entire MovieLens data.

        Intended as an alternative to vmapping the per-chunk calculation,
        which wastes a lot of computation and memory. Not implemented yet.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__}.log_prob_batched is not implemented; "
            "use log_prob instead.")

class BranchGaussianWithData(models_base.ModelBranched):
    """Branched Gaussian model whose children generate observed data.

    Structure (each a ``dists.DiagonalNormal`` with scale 1): θ of size
    ``D_par`` centred at 0; per chunk ``wi`` centred at θ; every row of
    ``data['x'][chunk]`` centred at ``wi``.

    ``data`` keys read here are ``'D_kid'``, ``'D_par'``, ``'x'``, plus the
    optional ``'use_mask'`` (default False) and ``'mask'``. With masking,
    row ``m`` of chunk ``c`` contributes only if ``m <= data['mask'][c]``.
    """
    def __init__(self, N_chunk:int, data:Dict):
        self.data = data
        self.use_mask = data.get('use_mask', False)
        self.D_kid = data['D_kid']
        self.D_par = data['D_par']
        super().__init__(N_chunk)

    def eval_parent(self, θ:Any, params:dict=None)->float:
        """Log-density of θ under ``DiagonalNormal(0, 1)`` of size ``D_par``."""
        return dists.DiagonalNormal(np.zeros(self.D_par), 1).log_prob(θ)

    def eval_child(self, θ:Any, params:Dict, wi:Any, chunk:int)->float:
        """log p(wi | θ) + Σ_m log p(x[chunk, m] | wi), masked when ``use_mask``."""
        logp_w =  dists.DiagonalNormal(θ, 1).log_prob(wi)
        # assumes this yields one value per row of x[chunk] — TODO confirm
        logp_x = dists.DiagonalNormal(
                            wi, 1).log_prob(self.data['x'][chunk])
        if self.use_mask:
            # np.arange needs the row count, not the data array itself.
            n_obs = len(self.data['x'][chunk])
            logp_x *= np.arange(n_obs) <= self.data['mask'][chunk]
        return logp_w + np.sum(logp_x)

    @classmethod
    def _forward_sample_parent(cls, rng_key, D):
        """Sample θ ~ N(0, I) of length ``D``."""
        return {'θ': jax.random.normal(rng_key, (D,))}

    @classmethod
    def _forward_sample_child(cls, rng_key, θ, N):
        """Sample ``N`` child vectors ``w`` from ``DiagonalNormal(θ, 1)``."""
        return {'w': dists.DiagonalNormal(
                            θ, 1).sample(rng_key, (N,))
                }

    @classmethod
    def _forward_sample_data(cls, rng_key, w, N, M, D):
        """Sample ``M`` observations per child; the sample axis is moved to position 1."""
        return {'x': np.moveaxis(
                        dists.DiagonalNormal(
                            w, 1).sample(rng_key, (M,)), 0, 1)
                }

    @classmethod
    def _forward_sample_mask(cls, rng_key, M, N):
        """Sample a per-chunk mask cutoff uniformly from ``[M//2, M)``."""
        return {'mask': jax.random.randint(rng_key, (N,), M//2, M)}

    @classmethod
    # functools.partial replaces the removed ``jax.partial`` alias. cls, M,
    # N, D and use_mask are static; only rng_key is traced.
    @functools.partial(jax.jit, static_argnums=(0, 2, 3, 4, 5))
    def _forward_sample(cls, rng_key, M, N, D, use_mask):
        """Draw a synthetic dataset (θ, w, data, and mask if ``use_mask``)."""
        data = {}
        rng_key, rng_skey = jax.random.split(rng_key)
        data.update(cls._forward_sample_parent(rng_skey, D))
        rng_key, rng_skey = jax.random.split(rng_key)
        data.update(cls._forward_sample_child(rng_skey, data['θ'], N))
        rng_key, rng_skey = jax.random.split(rng_key)
        data.update(cls._forward_sample_data(rng_skey, data['w'], N, M, D))
        if use_mask:
            rng_key, rng_skey = jax.random.split(rng_key)
            data.update(cls._forward_sample_mask(rng_skey, M, N))
        return data

class BranchConditionalGaussian(BranchGaussianWithData):
    """Hierarchical linear-Gaussian regression built on ``BranchGaussianWithData``.

    Per chunk, ``wi`` is centred at θ and the targets ``data['y'][chunk]``
    are modelled as ``DiagonalNormal(data['x'][chunk] @ wi, 1)``. When
    ``use_mask`` is set, observation ``m`` of a chunk counts only if
    ``m <= data['mask'][chunk]``. The training objective and the reported
    (mean) log-likelihoods all apply that same mask.
    """

    def _obs_mask(self, chunk):
        """Boolean vector marking the kept observations of ``chunk``."""
        return np.arange(len(self.data['y'][chunk])) <= self.data['mask'][chunk]

    def _obs_log_lik(self, wi, chunk):
        """Log-likelihood terms of ``chunk``'s targets given weights ``wi``.

        The terms are zeroed outside the mask window when ``use_mask`` is set.
        """
        # NOTE(review): the masking here and np.mean in eval_child_mean_ll
        # assume this log_prob is per-observation (shape (M,)). If
        # dists.DiagonalNormal reduces over its last axis (as its scalar use
        # for logp_w suggests), this is a scalar and both need an
        # elementwise normal instead — confirm.
        logp_y = dists.DiagonalNormal(
                        np.dot(self.data['x'][chunk], wi),
                        1).log_prob(self.data['y'][chunk])
        if self.use_mask:
            logp_y = logp_y * self._obs_mask(chunk)
        return logp_y

    def eval_child(self, θ:Any, params:Dict, wi:Any, chunk:int)->float:
        """log p(wi | θ) + Σ_m log p(y[chunk, m] | x[chunk, m], wi)."""
        logp_w =  dists.DiagonalNormal(θ, 1).log_prob(wi)
        return logp_w + np.sum(self._obs_log_lik(wi, chunk))

    def eval_child_ll(self, wi, chunk, **kwargs):
        """Total target log-likelihood for ``chunk``, or summed over all chunks.

        When ``chunk`` is None, ``wi`` must be stacked weights with leading
        axis ``N_chunk``. ``kwargs`` are accepted for interface compatibility
        and ignored.
        """
        def _chunk_ll(wi, chunk):
            return np.sum(self._obs_log_lik(wi, chunk))
        if chunk is not None:
            return _chunk_ll(wi, chunk)
        return np.sum(jax.vmap(_chunk_ll)(wi, np.arange(self.N_chunk)))

    def eval_child_mean_ll(self, wi, chunk, **kwargs):
        """Per-observation mean log-likelihood for ``chunk``, or averaged over chunks.

        With masking, the mean is taken over the kept observations only.
        When ``chunk`` is None, ``wi`` must have shape ``(N_chunk, D_kid)``.
        """
        def _chunk_mean_ll(wi, chunk):
            logp_y = self._obs_log_lik(wi, chunk)
            if self.use_mask:
                # Average only over observations inside the mask window.
                return np.sum(logp_y) / np.sum(self._obs_mask(chunk))
            return np.mean(logp_y)
        if chunk is not None:
            return _chunk_mean_ll(wi, chunk)
        assert wi.shape == (self.N_chunk, self.D_kid)
        return np.mean(jax.vmap(_chunk_mean_ll)(wi, np.arange(self.N_chunk)))

    @classmethod
    def _forward_sample_data(cls, rng_key, w, N, M, D):
        """Sample standard-normal inputs ``x`` (N, M, D) and targets ``y`` around ``x @ w``."""
        data = {}
        rng_key, rng_skey = jax.random.split(rng_key)
        data['x'] = jax.random.normal(
                            rng_skey, (N, M, D))
        rng_key, rng_skey = jax.random.split(rng_key)
        data['y'] = dists.DiagonalNormal(
                            utils.mv(
                                data['x'],
                                w),
                            1).sample(rng_skey)
        return data

class BranchConditionalBernoulli(BranchGaussianWithData):
    """Hierarchical logistic regression built on ``BranchGaussianWithData``.

    Per chunk, ``wi`` is centred at θ (``DiagonalNormal(θ, 1)``) and the
    targets ``data['y'][chunk]`` are Bernoulli with logits
    ``data['x'][chunk] @ wi``. When ``use_mask`` is set, observation ``m``
    counts only if ``m <= data['mask'][chunk]``.
    """

    def eval_child(self, θ:Any, params:Dict, wi:Any, chunk:int)->float:
        """log p(wi | θ) + Σ_m log p(y[chunk, m] | x[chunk, m], wi)."""
        log_prior = dists.DiagonalNormal(θ, 1).log_prob(wi)
        inputs = self.data['x'][chunk]
        targets = self.data['y'][chunk]
        log_lik = dists.Bernoulli(
                        logits=np.dot(inputs, wi)).log_prob(targets)
        if self.use_mask:
            keep = np.arange(len(targets)) <= self.data['mask'][chunk]
            log_lik *= keep
        return log_prior + np.sum(log_lik)

    @classmethod
    def _forward_sample_data(cls, rng_key, w, N, M, D):
        """Sample standard-normal inputs ``x`` (N, M, D) and Bernoulli targets ``y``."""
        # Two chained splits: x uses the sub-key of the first split, y the
        # sub-key of a split of the first split's carry key.
        carry_key, x_key = jax.random.split(rng_key)
        _, y_key = jax.random.split(carry_key)
        inputs = jax.random.normal(x_key, (N, M, D))
        targets = dists.Bernoulli(
                        logits=utils.mv(inputs, w)).sample(y_key)
        return {'x': inputs, 'y': targets}

class BranchGaussian(models_base.ModelBranched):
    """Branched model with multivariate-normal parent and children.

    The parent θ is ``MultivariateNormal(μ_par, scale_tril=L_par)``. Child
    ``wi`` of ``chunk`` is ``MultivariateNormal(μ_kid[chunk] + θ,
    scale_tril=L_kid[chunk])``. All four parameters are read from ``data``.
    """
    def __init__(self, N_chunk, data):
        self.data = data
        self.μ_par, self.L_par = data['μ_par'], data['L_par']
        self.μ_kid, self.L_kid = data['μ_kid'], data['L_kid']

        # Event sizes come from the trailing axis of the mean arrays.
        self.D_par = self.μ_par.shape[-1]
        self.D_kid = self.μ_kid.shape[-1]
        super().__init__(N_chunk)

    def eval_child(self, θ, params, wi, chunk):
        """log p(wi | θ) for the child indexed by ``chunk``."""
        child_dist = dists.MultivariateNormal(
                        self.μ_kid[chunk] + θ,
                        scale_tril=self.L_kid[chunk])
        return child_dist.log_prob(wi)

    def eval_parent(self, θ, params):
        """log p(θ) under the fixed parent distribution."""
        parent_dist = dists.MultivariateNormal(
                        self.μ_par,
                        scale_tril=self.L_par)
        return parent_dist.log_prob(θ)
